{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load IPython's autoreload extension and select mode 2, which reloads\n",
    "# imported modules before each cell executes, so edits to the local `src`\n",
    "# package are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from src.utils.Config import Config\n",
    "from src.utils.Logging import Logger\n",
    "from src.components.memory import ReplayBuffer\n",
    "from src.networks.models import QNetwork\n",
    "\n",
    "from src.utils.misc import train, watch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agent Configuration:\n",
      "env: \t\tEnvSpec(CartPole-v1)\n",
      "win condition: \t195.0\n",
      "device: \tcpu\n",
      "seed: \t\t123456789\n",
      "n_episodes: \t2000\n",
      "max_t: \t\t1000\n",
      "eps_start: \t1.0\n",
      "eps_end: \t0.01\n",
      "eps_decay: \t0.995\n",
      "eps_greedy: \tTrue\n",
      "noisy: \t\tFalse\n",
      "tau: \t\t0.001\n",
      "gamma: \t\t0.99\n",
      "lr: \t\t0.0005\n",
      "memory: \t<class 'src.components.memory.ReplayBuffer'>\n",
      "batch_size: \t64\n",
      "buffer_size: \t100000\n",
      "lr_annealing: \tFalse\n",
      "learn_every: \t4\n",
      "double_dqn: \tFalse\n",
      "model: \t\t<class 'src.networks.models.QNetwork'>\n",
      "save_loc: \tNone\n",
      "<_sre.SRE_Match object; span=(0, 20), match='EnvSpec(CartPole-v1)'>\n",
      "Logging at: logs/CartPole-v1/experiment-2020-04-23_08_32_15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated.  Call .resolve and .require separately.\n",
      "  result = entry_point.load(False)\n"
     ]
    }
   ],
   "source": [
    "# Build a Config with its default values and override only the fields set below;\n",
    "# every other value shown by print_config() is a Config default.\n",
    "config = Config()\n",
    "\n",
    "config.env = gym.make('CartPole-v1')\n",
    "\n",
    "# NOTE(review): 195.0 is the reward threshold gym registers for CartPole-v0;\n",
    "# CartPole-v1 is registered with 475.0 -- confirm this target is intended.\n",
    "config.win_condition = 195.0\n",
    "# memory and model are assigned the classes themselves, not instances.\n",
    "config.memory = ReplayBuffer\n",
    "config.model = QNetwork\n",
    "config.print_config()\n",
    "\n",
    "# The Logger is created from the finished config; later cells pass it to train()\n",
    "# and read logger.score / logger.log_file_path from it.\n",
    "logger = Logger(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epi: 100\t Frame: 2302\tAverage Score: 23.0200\tMean: 17.0000\tDuration: 0.03\t#t_s: 16.0\n",
      "Epi: 200\t Frame: 3850\tAverage Score: 15.4800\tMean: 11.0000\tDuration: 0.02\t#t_s: 10.0\n",
      "Epi: 300\t Frame: 5388\tAverage Score: 15.3800\tMean: 10.0000\tDuration: 0.02\t#t_s: 9.0\n",
      "Epi: 400\t Frame: 6702\tAverage Score: 13.1400\tMean: 9.0000\tDuration: 0.02\t#t_s: 8.0\n",
      "Epi: 500\t Frame: 8038\tAverage Score: 13.3600\tMean: 10.0000\tDuration: 0.03\t#t_s: 9.0\n",
      "Epi: 600\t Frame: 9962\tAverage Score: 19.2400\tMean: 14.0000\tDuration: 0.05\t#t_s: 13.0\n",
      "Epi: 700\t Frame: 15778\tAverage Score: 58.1600\tMean: 239.0000\tDuration: 0.43\t#t_s: 238.0\n",
      "Epi: 764\t Frame: 31877 \tAverage: 195.9400\tMean: 238.0000\tDuration: 0.49\t#t_s: 237.0\n",
      "Environment Solved in 68.0009 seconds !\n"
     ]
    }
   ],
   "source": [
    "# Run training with the config and logger built in the previous cell.\n",
    "# Progress lines are printed periodically (see the recorded output below).\n",
    "train(config, logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the values stored in `logger.score` after training.\n",
    "# `plt` is matplotlib.pyplot, imported in the imports cell at the top of the\n",
    "# notebook; the trailing semicolon keeps the Line2D repr out of the output.\n",
    "plt.plot(logger.score);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call watch() with the config and the path the logger exposes as `log_file_path`\n",
    "# (presumably replays the trained agent from that run -- confirm in src.utils.misc).\n",
    "watch(config, logger.log_file_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
